package com.gzxx.mongo.impl;

import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.gzxx.mongo.MongoSqlLogServer;
import com.gzxx.mongo.utils.FormatUtils;
import org.bson.Document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.mongodb.core.convert.MongoConverter;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.convert.UpdateMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.List;

/**
 * Logs MongoDB operations as mongo-shell style statements such as
 * {@code db.<collection>.find(...)}, {@code db.<collection>.update(...)} or
 * {@code db.<collection>.save(...)}, at INFO level.
 *
 * <p>Output is controlled by the property {@code spring.data.mongodb.print} (default
 * {@code false}). When it is disabled every method returns immediately without mapping or
 * serializing anything.
 *
 * @Author pengpdx
 * @Date 2020/7/22 18:43
 */
@Service
public class MongoSqlLogServerImpl implements MongoSqlLogServer {

    private static final Logger logger = LoggerFactory.getLogger(MongoSqlLogServerImpl.class);

    /**
     * Whether statements are printed (property {@code spring.data.mongodb.print}, default false).
     */
    @Value("${spring.data.mongodb.print:false}")
    private boolean printSql;

    @Autowired
    private MongoConverter mongoConverter;

    /** Maps query / projection / sort documents through {@link #mongoConverter}. */
    private QueryMapper queryMapper;

    /** Maps update documents through {@link #mongoConverter}. */
    private UpdateMapper updateMapper;

    /**
     * Creates the mappers once {@link #mongoConverter} has been injected. Because the converter
     * is field-injected, the mappers cannot be built in field initializers.
     */
    @PostConstruct
    public void init() {
        queryMapper = new QueryMapper(mongoConverter);
        updateMapper = new UpdateMapper(mongoConverter);
    }

    /**
     * Logs a find query as
     * {@code db.<collection>.find(filter)[.projection(..)][.sort(..)][.limit(n)][.skip(n)];}.
     *
     * @param clazz entity type the query targets; used for field mapping and the collection name
     * @param query the query to log
     */
    @Override
    public void logQuery(Class<?> clazz, Query query) {
        if (!printSql) {
            return;
        }
        MongoPersistentEntity<?> entity = persistentEntity(clazz);
        Document mappedQuery = queryMapper.getMappedObject(query.getQueryObject(), entity);

        StringBuilder statement = new StringBuilder(statementPrefix(clazz, entity, "find"));
        statement.append(FormatUtils.bson(mappedQuery.toJson())).append(')');

        // Projection and sort are only mapped when they will actually be printed.
        if (!query.getFieldsObject().isEmpty()) {
            Document mappedField = queryMapper.getMappedObject(query.getFieldsObject(), entity);
            statement.append(".projection(").append(FormatUtils.bson(mappedField.toJson())).append(')');
        }
        if (query.isSorted()) {
            Document mappedSort = queryMapper.getMappedObject(query.getSortObject(), entity);
            statement.append(".sort(").append(FormatUtils.bson(mappedSort.toJson())).append(')');
        }
        if (query.getLimit() != 0) {
            statement.append(".limit(").append(query.getLimit()).append(')');
        }
        if (query.getSkip() != 0L) {
            statement.append(".skip(").append(query.getSkip()).append(')');
        }
        statement.append(';');

        logger.info(statement.toString());
    }

    /**
     * Logs a count as {@code db.<collection>.find(filter).count();}.
     *
     * @param clazz entity type the query targets; used for field mapping and the collection name
     * @param query the query whose filter is counted
     */
    @Override
    public void logCount(Class<?> clazz, Query query) {
        if (!printSql) {
            return;
        }
        MongoPersistentEntity<?> entity = persistentEntity(clazz);
        Document mappedQuery = queryMapper.getMappedObject(query.getQueryObject(), entity);

        String statement = statementPrefix(clazz, entity, "find")
                + FormatUtils.bson(mappedQuery.toJson()) + ")"
                + ".count();";
        logger.info(statement);
    }

    /**
     * Logs a delete as {@code db.<collection>.remove(filter);}.
     *
     * @param clazz entity type the query targets; used for field mapping and the collection name
     * @param query the query selecting the documents to remove
     */
    @Override
    public void logDelete(Class<?> clazz, Query query) {
        if (!printSql) {
            return;
        }
        MongoPersistentEntity<?> entity = persistentEntity(clazz);
        Document mappedQuery = queryMapper.getMappedObject(query.getQueryObject(), entity);

        String statement = statementPrefix(clazz, entity, "remove")
                + FormatUtils.bson(mappedQuery.toJson()) + ")"
                + ";";
        logger.info(statement);
    }

    /**
     * Logs an update as {@code db.<collection>.update(filter, update, {multi:<multi>});}.
     *
     * @param clazz  entity type the update targets; used for field mapping and the collection name
     * @param query  the query selecting the documents to update
     * @param update the update to apply
     * @param multi  printed as the {@code multi} option of the statement
     */
    @Override
    public void logUpdate(Class<?> clazz, Query query, Update update, boolean multi) {
        if (!printSql) {
            return;
        }
        MongoPersistentEntity<?> entity = persistentEntity(clazz);
        Document mappedQuery = queryMapper.getMappedObject(query.getQueryObject(), entity);
        Document mappedUpdate = updateMapper.getMappedObject(update.getUpdateObject(), entity);

        // Only the options document goes through the formatter; the closing ')' of the
        // shell call is appended afterwards, like every other statement in this class.
        String statement = statementPrefix(clazz, entity, "update")
                + FormatUtils.bson(mappedQuery.toJson()) + ","
                + FormatUtils.bson(mappedUpdate.toJson()) + ","
                + FormatUtils.bson("{multi:" + multi + "}") + ")"
                + ";";
        logger.info(statement);
    }

    /**
     * Logs a single save as {@code db.<collection>.save(<object as pretty JSON>);}.
     * Nothing is logged for a {@code null} object.
     *
     * @param object the object being saved
     */
    @Override
    public void logSave(Object object) {
        if (!printSql || object == null) {
            return;
        }
        logger.info(saveStatement(object.getClass(), object));
    }

    /**
     * Logs a batch save as {@code db.<collection>.save(<list as pretty JSON>);}.
     *
     * <p>The collection is derived from the first non-null element. Nothing is logged for a
     * {@code null} list, an empty list, or a list containing only {@code null}s.
     *
     * @param list the objects being saved
     */
    @Override
    public void logSave(List<?> list) {
        if (!printSql || list == null || list.isEmpty()) {
            return;
        }
        Object sample = firstNonNull(list);
        if (sample == null) {
            return;
        }
        logger.info(saveStatement(sample.getClass(), list));
    }

    /** Looks up the persistent entity for {@code clazz}; may be {@code null}. */
    private MongoPersistentEntity<?> persistentEntity(Class<?> clazz) {
        return mongoConverter.getMappingContext().getPersistentEntity(clazz);
    }

    /**
     * Returns the collection name: the entity's mapped collection when an entity is available,
     * otherwise the simple class name with a lower-case first letter.
     */
    private String collectionName(Class<?> clazz, MongoPersistentEntity<?> entity) {
        if (entity != null) {
            return entity.getCollection();
        }
        return StrUtil.lowerFirst(clazz.getSimpleName());
    }

    /** Builds the common {@code "\ndb.<collection>.<operation>("} statement prefix. */
    private String statementPrefix(Class<?> clazz, MongoPersistentEntity<?> entity, String operation) {
        return "\ndb." + collectionName(clazz, entity) + "." + operation + "(";
    }

    /** Builds a {@code save} statement for {@code payload}, using {@code clazz} for the collection. */
    private String saveStatement(Class<?> clazz, Object payload) {
        return statementPrefix(clazz, persistentEntity(clazz), "save")
                + JSONUtil.toJsonPrettyStr(payload)
                + ");";
    }

    /** Returns the first non-null element of {@code list}, or {@code null} if there is none. */
    private static Object firstNonNull(List<?> list) {
        for (Object element : list) {
            if (element != null) {
                return element;
            }
        }
        return null;
    }
}
